import matplotlib.pyplot as plt # type: ignore
import numpy as np # type: ignore
import pandas as pd # type: ignore

def plot_colormaps(freqi = 6000,  tunnel = '', case = ''):
    """Plot displacement-amplitude colormaps for the 'box' and 'slice' models.

    Builds a 3x2 grid of panels: column 0 is the 'box' model, column 1 the
    'slice' model; rows are the Z, X and Y amplitude components.  Each model
    is read from ``comsol_raw/<model>_colormap<tunnel>[<case>].csv`` (``case``
    is appended only for the slice model), with the first 8 lines skipped and
    the first two columns used as the (x, y) index.  The figure is saved to
    ``manufigs/colormap<tunnel><case>_<freqi>Hz.png`` and then shown.

    Parameters
    ----------
    freqi : int
        Frequency used to select the ``... @ freq=<freqi>`` result columns.
    tunnel : str
        Suffix inserted into the input and output filenames.
    case : str
        Suffix appended to the slice-model input filename and the output
        filename.
    """
    fig, axarr = plt.subplots(3, 2, sharex='col', figsize = [4.5,5])

    for i, model in enumerate(['box', 'slice']):
        print(i, model)
        suffix = tunnel + (case if model == 'slice' else '')
        df = pd.read_csv('comsol_raw/' + model + '_colormap' + suffix + '.csv',
                         skiprows = 8, index_col = [0,1])
        # np.unique returns the sorted distinct grid coordinates on each axis.
        x = np.unique(df.index.get_level_values(0))
        y = np.unique(df.index.get_level_values(1))

        X, Y = np.meshgrid(x, y)

        for j, component in enumerate(['Z', 'X', 'Y']):
            print(j, component)
            ax = axarr[j, i]

            column = 'solid.uAmp' + component + ' (mm) @ freq=' + str(freqi)
            # convert mm to nm, and then reduce by factor of 10
            z = np.asarray(df.loc[:, column]) * 1e6 / 10

            # Reshape to the meshgrid shape (len(y), len(x)) rather than a
            # hard-coded 1000x1000 grid; pcolormesh needs Z to match X/Y.
            Z = np.reshape(z, X.shape)

            im = ax.pcolormesh(Y, X, Z, cmap = 'jet', vmin = 0, vmax = 18)

            ax.set_xlim(-0.07, 0.07)
            ax.set_ylim(-0.02, 0.07)
            ax.set_yticks([0,0.05])
            ax.set_xticks([-0.05, 0, 0.05])
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)

            ax.set_aspect(1)

    # Every panel uses the same y limits, so the right-hand column's y axis is
    # redundant; hide it.
    for ax in axarr[:, 1]:
        ax.spines['left'].set_visible(False)
        ax.set_yticks([])

    # Leave room on the right for a single shared colorbar (all panels use
    # the same vmin/vmax, so the last mesh is representative).
    fig.subplots_adjust(right=0.85)
    cbar_ax = fig.add_axes([0.88, 0.2, 0.03, 0.6])
    fig.colorbar(im, cax=cbar_ax)

    fig.savefig('manufigs/colormap' + tunnel + case + '_' + str(freqi) + 'Hz.png')
    plt.show()

def plot_colormaps_phase(freqi = 6000, tunnel = '', case = ''):
    """Plot displacement-phase colormaps for the 'box' and 'slice' models.

    Builds a 3x2 grid of panels: column 0 is the 'box' model, column 1 the
    'slice' model; rows are the Z, X and Y phase components.  Each model is
    read from ``comsol_raw/<model>_colormap_phase<tunnel>[<case>].csv``
    (``case`` is appended only for the slice model).  All three components
    are taken relative to the first non-NaN Z-phase sample.  They are folded
    by multiples of 2*pi into a component-specific window and plotted in
    units of 2*pi rad.  The figure is saved to
    ``manufigs/colormap<tunnel><case>_<freqi>Hz_phase.png`` and then shown.

    Parameters
    ----------
    freqi : int
        Frequency used to select the ``... @ freq=<freqi>`` result columns.
    tunnel : str
        Suffix inserted into the input and output filenames.
    case : str
        Suffix appended to the slice-model input filename and the output
        filename.

    Raises
    ------
    ValueError
        If a model's Z-phase column is entirely NaN, so no reference phase
        can be chosen.
    """
    fig, axarr = plt.subplots(3, 2, sharex='col', figsize = [4.5,5])

    for i, model in enumerate(['box', 'slice']):
        print(i, model)
        suffix = tunnel + (case if model == 'slice' else '')
        df = pd.read_csv('comsol_raw/' + model + '_colormap_phase' + suffix + '.csv',
                         skiprows = 8, index_col = [0,1])
        # np.unique returns the sorted distinct grid coordinates on each axis.
        x = np.unique(df.index.get_level_values(0))
        y = np.unique(df.index.get_level_values(1))
        X, Y = np.meshgrid(x, y)

        phase_z = np.asarray(df.loc[:, 'solid.uPhaseZ (rad) @ freq='+str(freqi)])
        phase_x = np.asarray(df.loc[:, 'solid.uPhaseX (rad) @ freq='+str(freqi)])
        phase_y = np.asarray(df.loc[:, 'solid.uPhaseY (rad) @ freq='+str(freqi)])

        # Reference phase: the first sample whose Z phase is not NaN.
        valid = np.flatnonzero(~np.isnan(phase_z))
        if valid.size == 0:
            raise ValueError('all Z-phase values are NaN for model '
                             + repr(model) + ' at freq=' + str(freqi))
        k = int(valid[0])
        phase_shift = phase_z[k]
        print(k, phase_shift)

        for j, phase in enumerate([phase_z, phase_x, phase_y]):
            print(j)
            ax = axarr[j, i]

            z = phase - phase_shift

            if j in (0, 1):
                # Applied in this order, as two separate steps: first shift
                # values below -1 rad up by one period, then shift values
                # above the upper bound (pi for Z, 2 for X) down by one period.
                upper = np.pi if j == 0 else 2
                z = np.where(z < -1, z + 2*np.pi, z)
                z = np.where(z > upper, z - 2*np.pi, z)
            else:
                # Repeatedly subtract one period until every value is <= 1.
                # Non-finite values are skipped (+inf would never terminate).
                above = np.isfinite(z) & (z > 1)
                while above.any():
                    z[above] -= 2*np.pi
                    above = np.isfinite(z) & (z > 1)

            # Reshape to the meshgrid shape (len(y), len(x)) rather than a
            # hard-coded 1000x1000 grid, then express the phase in cycles.
            Z = np.reshape(z, X.shape)/2/np.pi

            im = ax.pcolormesh(Y, X, Z, cmap = 'jet', vmin = -1, vmax = 1)
            ax.set_xlim(-0.07, 0.07)
            ax.set_ylim(-0.02, 0.07)
            ax.set_yticks([0,0.05])
            ax.set_xticks([-0.05, 0, 0.05])
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)

            ax.set_aspect(1)

    # Every panel uses the same y limits, so the right-hand column's y axis is
    # redundant; hide it.
    for ax in axarr[:, 1]:
        ax.spines['left'].set_visible(False)
        ax.set_yticks([])

    # Leave room on the right for a single shared colorbar (all panels use
    # vmin=-1, vmax=1).
    fig.subplots_adjust(right=0.85)
    cbar_ax = fig.add_axes([0.88, 0.2, 0.03, 0.6])
    cbar = fig.colorbar(im, cax=cbar_ax, ticks = [-1, -0.5, 0, 0.5, 1])
    cbar.ax.set_yticklabels(['-1','', '0','','1'])

    fig.savefig('manufigs/colormap' + tunnel + case + '_' + str(freqi)+'Hz_phase.png')
    plt.show()
    

# Run settings: `case` is appended to the slice-model input files and to the
# output filenames; `tunnel` is inserted into every input/output filename.
case = '_st'
tunnel = ''
for freqi in (6000,):
    print(freqi)
    plot_colormaps(freqi, tunnel, case)
    plot_colormaps_phase(freqi, tunnel, case)
